import numpy as np
from scipy.optimize import fmin_tnc

class RegularizedMNLRegression:
    """Ridge-regularized multinomial (conditional) logit fitted with TNC.

    Data layout used by every method:

    * ``x`` has shape ``(m, J - 1, d)``: ``d`` features for each of the
      ``J - 1`` non-reference alternatives of each of ``m`` observations.
      A 2-D ``(m, d)`` array is also accepted and is treated as the binary
      case ``J == 2``.
    * ``theta`` has shape ``(d,)``.
    * ``y`` has shape ``(m, J)``; the LAST column is the reference
      alternative, whose utility is fixed at 0.  The gradient formula
      assumes each row of ``y`` sums to 1 (e.g. one-hot choices).

    The objective minimized is::

        -(1/m) * sum(y * log(prob)) + lam / (2 * m) * ||theta||^2
    """

    def _log_prob(self, theta, x):
        """Return the ``(m, J)`` log choice probabilities (log-softmax)."""
        utilities = np.dot(x, theta)
        # Append the reference alternative, whose utility is 0 (exp(0) == 1).
        z = np.column_stack((utilities, np.zeros(utilities.shape[0])))
        # Shift by the row maximum so np.exp can never overflow; the shift
        # cancels in the softmax, so the result is mathematically unchanged.
        z = z - z.max(axis=1, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

    def compute_prob(self, theta, x, y):
        """Return the ``(m, J)`` matrix of choice probabilities.

        ``y`` is not used; it is accepted only so the signature mirrors the
        other methods.
        """
        return np.exp(self._log_prob(theta, x))

    def cost_function(self, theta, *args):
        """Return the regularized mean negative log-likelihood.

        Args:
            theta: parameter vector of shape ``(d,)``.
            *args: ``(x, y, lam)`` -- features, targets and the ridge
                strength.

        Returns:
            The scalar objective described in the class docstring.
        """
        x, y, lam = args[0], args[1], args[2]
        theta = np.asarray(theta, dtype=float)
        m = np.asarray(x).shape[0]
        # Work with log-probabilities directly: log(prob) would produce -inf
        # (and 0 * -inf == nan) whenever a probability underflows to 0.
        log_prob = self._log_prob(theta, x)
        nll = -np.sum(np.multiply(y, log_prob)) / m
        # Squared-L2 (ridge) penalty; its derivative is exactly the
        # ``lam / m * theta`` term returned by ``gradient``.
        penalty = lam / (2.0 * m) * np.sum(np.square(theta))
        return nll + penalty

    def gradient(self, theta, *args):
        """Return the gradient of ``cost_function`` with respect to ``theta``.

        Args:
            theta: parameter vector of shape ``(d,)``.
            *args: ``(x, y, lam)`` -- same as for ``cost_function``.

        Returns:
            Array of shape ``(d,)``.
        """
        x, y, lam = args[0], args[1], args[2]
        theta = np.asarray(theta, dtype=float)
        x = np.asarray(x)
        if x.ndim == 2:
            # Binary case: a plain design matrix is one non-reference
            # alternative per observation.
            x = x[:, np.newaxis, :]
        m = x.shape[0]
        prob = self.compute_prob(theta, x, y)
        # The reference alternative has no features, so drop its column.
        eps = (prob - y)[:, :-1]
        # Contract over observations and alternatives, leaving the feature
        # axis: grad[k] = sum_{i, j} eps[i, j] * x[i, j, k].
        data_term = np.tensordot(eps, x, axes=([1, 0], [1, 0]))
        return (data_term + lam * theta) / m

    def fit(self, theta, *args):
        """Minimize the objective with ``fmin_tnc`` starting from ``theta``.

        Args:
            theta: initial parameter vector.
            *args: ``(x, y, lam)``, forwarded to ``cost_function`` and
                ``gradient``.

        Returns:
            ``self``; the optimized parameters are stored in ``self.w``.
        """
        opt_weights = fmin_tnc(
            func=self.cost_function,
            x0=theta,
            fprime=self.gradient,
            args=args,
            disp=0,
        )
        # fmin_tnc returns (solution, n_function_evals, return_code).
        self.w = opt_weights[0]
        return self